# Spread data (make wide) -----------------------------------------------------
# Reshape one group's long-format data into wide format, with one column per
# `Unit` holding that unit's `z.score`. Columns that are entirely NA are dropped
# both before and after spreading.
#
# Arguments:
#   Data              long-format data of a single group (one nested element in
#                     get_perievents()); must contain `Unit` and `z.score`.
#   ID                group identifier, used only in debug logging.
#   permute           if TRUE, the `perievent.columns` are shuffled
#                     independently within each `TrialCnt` group.
#   perievent.columns names of the event columns to shuffle when permuting.
#
# Returns the wide data frame, or an empty data.frame() when is_empty(Data).
# On any error (including NA values remaining after spreading) the error is
# logged and the unmodified long-format `Data` is returned.
spread_data <- function(Data, ID, permute = FALSE, perievent.columns){
  # Shuffle the elements of x. Unlike sample(x), a single numeric value n >= 1
  # is not turned into a permutation of 1:n, so one-row groups keep length 1.
  # For length(x) > 1 this consumes the RNG exactly like sample(x).
  shuffle <- function(x) x[sample.int(length(x))]

  tryCatch({
    flog.debug("Spreading data", name = log.name)
    flog.debug(paste0("ID: ", ID), name = log.name)

    if (is_empty(Data)) {
      flog.warn("Data frame empty for spreading!", name = log.name)
      return(data.frame())
    }

    Spread.data <- Data %>%
      select_if(function(col) !all(is.na(col))) %>% # removes NA only columns
      spread(Unit, z.score) %>%
      select_if(function(col) !all(is.na(col)))     # removes empty units

    if (permute) {
      flog.debug("Permuting", name = log.name)
      Spread.data <- Spread.data %>%
        group_by_at("TrialCnt") %>%
        mutate_at(vars(one_of(perievent.columns)), shuffle) %>%
        ungroup()
    }

    # downstream row indexing expects a complete wide table
    stopifnot(!any(is.na(Spread.data)))

    return(Spread.data)
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data)
  }) # end try catch
}

# Get onset indices -----------------------------------------------------------
# Build the index table of event onsets from wide data (see spread_data()).
# A row is an onset of an event column when that column's value is > 0. The
# window around each onset is expressed in rows:
#   Onset  = idx.row + before / bin.size
#   Offset = idx.row + after / bin.size
#
# Arguments:
#   Data              wide data of one group.
#   ID                group identifier, used only in debug logging.
#   before, after     window limits relative to the onset, in the same time
#                     unit as bin.size.
#   bin.size          duration of one row.
#   perievent.columns names of the event columns to scan.
#
# Returns a data frame with columns ID (name of the event column), idx.row,
# Onset, Offset and EventID (running number 1..n over all onsets). When `Data`
# is empty, no onset is found, or an error occurs (which is logged), a
# zero-row table with these same columns is returned.
get_idx <- function(Data, ID, before, after, bin.size = 1, perievent.columns){
  # zero-row index table with the same columns as a populated one
  empty.idx <- data.frame(ID = character(0), idx.row = integer(0),
                          Onset = numeric(0), Offset = numeric(0),
                          EventID = numeric(0), stringsAsFactors = FALSE)

  tryCatch({
    flog.debug("Getting idx of event", name = log.name)
    flog.debug(paste0("ID: ", ID), name = log.name)

    if (is_empty(Data)) {
      flog.warn("Data frame empty for spreading!", name = log.name)
      return(empty.idx)
    }

    flog.debug("Selecting idx...", name = log.name)
    # select only event columns and screen them for onset rows
    event.data <- Data %>% select_at(vars(one_of(perievent.columns)))
    idx.onsets <- map(event.data, function(x) which(x > 0))

    # one table per event column; columns without onsets give NULL, which
    # bind_rows() skips
    per.column <- lapply(names(idx.onsets), function(column) {
      flog.debug(paste0("Selecting columns: ", column), name = log.name)
      if (length(idx.onsets[[column]]) == 0) {
        flog.warn(paste0("Empty column, skipping: ", column), name = log.name)
        return(NULL)
      }
      data.frame(ID = column, idx.row = idx.onsets[[column]],
                 stringsAsFactors = FALSE)
    })
    return.df <- bind_rows(per.column)

    # without this guard mutate() below fails on the missing idx.row column
    if (nrow(return.df) == 0) {
      flog.warn("No event onsets found", name = log.name)
      return(empty.idx)
    }

    flog.debug("Adding onset and offset", name = log.name)
    return.df %>%
      mutate(Onset = idx.row + before / bin.size,
             Offset = idx.row + after / bin.size,
             EventID = as.numeric(seq_len(n())))
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(empty.idx)
  }) # end try catch
}

# Select events ---------------------------------------------------------------
# Cut the peri-event windows listed in an index table (see get_idx()) out of
# the wide data and return them in long format.
#
# Arguments:
#   Data.event     wide data of one group; unit columns start with "SPK".
#   Idx.data       index table with columns ID, Onset, Offset and EventID.
#   before, after  window limits used to build Idx.data; they define the time
#                  axis norm.time = seq(before, after, by = bin.size).
#   bin.size       duration of one row.
#
# Returns a long data frame with columns norm.time, EventID, Event (the event
# column name), Unit and z.score. Windows that do not lie completely inside
# Data.event are skipped. If nothing can be selected, or on any error (which is
# logged), an empty data.frame() is returned.
select_events <- function(Data.event, Idx.data, before, after, bin.size = 1) {
  tryCatch({
    flog.debug("Selecting events from idx", name = log.name)
    stopifnot(nrow(Data.event) > 0, nrow(Idx.data) > 0)

    # remove out of bound indices; rows are 1-based and the window is
    # inclusive, so an Offset equal to the last row is still valid
    Idx.data <- Idx.data %>%
      filter(Onset > 0 & Offset <= nrow(Data.event))

    if (nrow(Idx.data) == 0) {
      flog.warn("No event window within data bounds", name = log.name)
      return(data.frame())
    }

    # loop invariants: the unit columns and the peri-event time axis
    spike.data <- Data.event %>% select(starts_with("SPK"))
    time.axis <- seq(from = before, to = after, by = bin.size)

    flog.debug("Running loop for selection", name = log.name)
    # `[[` yields scalars for both data.frames and tibbles
    selected.event <- lapply(seq_len(nrow(Idx.data)), function(i) {
      spike.data %>%
        slice(Idx.data[["Onset"]][i]:Idx.data[["Offset"]][i]) %>%
        mutate(norm.time = time.axis,
               EventID = Idx.data[["EventID"]][i],
               Event = Idx.data[["ID"]][i])
    })

    flog.debug("Binding selected events", name = log.name)
    selected.event.df <- bind_rows(selected.event)

    stopifnot(any(!is.na(selected.event.df$EventID)))

    selected.event.df <- selected.event.df %>%
      gather(key = Unit, value = z.score, contains("SPK"))

    flog.debug("Selection done.", name = log.name)
    return(selected.event.df)
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    # `Data` is not an argument here; return an empty result instead
    return(data.frame())
  }) # end
}
  
# Normalize perievents to baseline firing -------------------------------------
# Works on the peri-event data of one group (columns norm.time, Time.period,
# Event, Unit, z.score); every Event/Unit pair is handled separately.
# First averages z.score over all occurrences (event IDs) of an Event per time
# bin. Then it re-standardizes each Event/Unit trace against its own baseline
# window [min.time, max.time]:
#   z = (z - mean(baseline z)) / sd(baseline z)
# Results that are infinite or NaN (e.g. sd of 0) are set to NA. Pairs without
# baseline bins, or with a single baseline bin (sd is NA), end up NA as well.
#
# Returns columns norm.time, Time.period, Event, Unit and z.score. On error the
# error is logged and `Data` is returned unchanged.
normalize_perievents <- function(Data, min.time = -2, max.time = -1) {
  tryCatch({
    flog.info("Normalizing firing per baseline", name = log.name)
    stopifnot(min.time < max.time)
    stopifnot(nrow(Data) > 0)

    flog.debug("Computing z-score", name = log.name)

    # average over all occurrences of an event, per time bin
    Data.normalized <- Data %>%
      group_by(norm.time, Time.period, Event, Unit) %>%
      summarise(z.score = mean(z.score, na.rm = TRUE)) %>%
      ungroup()

    # baseline statistics per Event/Unit (named `baseline`, not `mean`, so
    # base::mean is not shadowed)
    baseline <- Data.normalized %>%
      filter(between(norm.time, min.time, max.time)) %>%
      group_by(Event, Unit) %>%
      summarize(mean.z.score = mean(z.score, na.rm = TRUE),
                sd.z.score = sd(z.score, na.rm = TRUE)) %>%
      ungroup()

    Data.normalized <- Data.normalized %>%
      left_join(baseline, by = c("Event", "Unit")) %>%
      mutate(z.score = (z.score - mean.z.score) / sd.z.score) %>%
      select(-mean.z.score, -sd.z.score) %>%
      mutate(z.score = if_else(is.infinite(z.score) | is.nan(z.score),
                               NA_real_, z.score))

    return(Data.normalized)
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data)
  }) # end
}

# Assign firing state -----------------------------------------------------------
# Threshold the normalized firing rate as defined by others: if any time bin at
# onset or in the post-onset window passes the threshold value, assign the
# corresponding Direction (Up/Down).
#
# Only bins with threshold.min.time <= norm.time <= threshold.max.time are
# used. For each group/Event/Unit the first matching rule wins:
#   any z.score is NA                                  -> "Not computed"
#   z.score at norm.time == 0 above +z.threshold       -> "Up"
#   z.score at norm.time == 0 below -z.threshold       -> "Down"
#   more than 30% of bins above / below                -> "Up" / "Down"
#   larger share of bins above than below / reverse    -> "Up" / "Down"
#   any bin above / below                              -> "Up" / "Down"
#   otherwise                                          -> "Nonresponder"
#
# Arguments:
#   Data.normalized  data with columns norm.time, Event, Unit, z.score and the
#                    grouping columns.
#   z.threshold      absolute z-score threshold.
#   threshold.min.time, threshold.max.time  window used for thresholding.
#   group.columns    grouping columns. When NULL, a variable `group.columns`
#                    visible from where this function is defined is used
#                    (the previous, implicit behaviour).
#
# Returns Data.normalized with a Direction column joined onto every row. On
# error the error is logged and Data.normalized is returned unchanged.
assign_threshold <- function(Data.normalized, z.threshold = 2,
                             threshold.min.time = 0, threshold.max.time = 1,
                             group.columns = NULL) {
  # environment where the free variable `group.columns` used to be resolved
  lookup.env <- parent.env(environment())

  tryCatch({
    if (is.null(group.columns)) {
      group.columns <- get("group.columns", envir = lookup.env)
    }

    # split the units depending on z value, also valid for session z-score
    flog.info("Thresholding", name = log.name)
    Data.thresholded <- Data.normalized %>%
      filter(between(norm.time, threshold.min.time, threshold.max.time)) %>%
      group_by_at(c(group.columns, "Event", "Unit")) %>%
      summarise(Direction = case_when(
        any(is.na(z.score)) ~ "Not computed",
        z.score[norm.time == 0] > z.threshold ~ "Up",
        z.score[norm.time == 0] < -z.threshold ~ "Down",
        mean(z.score > z.threshold) > 0.3 ~ "Up",
        mean(z.score < -z.threshold) > 0.3 ~ "Down",
        mean(z.score > z.threshold) > mean(z.score < -z.threshold) ~ "Up",
        mean(z.score > z.threshold) < mean(z.score < -z.threshold) ~ "Down",
        #any(z.score > z.threshold) & any(z.score < -z.threshold) ~ "Both",
        any(z.score > z.threshold) ~ "Up",
        any(z.score < -z.threshold) ~ "Down",
        TRUE ~ "Nonresponder")) %>%
      ungroup()

    flog.debug("Merging columns", name = log.name)
    Data.normalized <- left_join(Data.normalized, Data.thresholded,
                                 by = c(group.columns, "Event", "Unit"))

    return(Data.normalized)
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data.normalized)
  }) # end
}

# Assign firing state using one-sample t-tests --------------------------------
# Classify the data of one group/Event/Unit (threshold_units() nests by those
# columns) using one-sample t-tests. One test is run per time bin within
# [threshold.min.time, threshold.max.time], with one sample per row of the
# spread table (per ID/EventID/Time.period).
#
# mu of the t-tests and classification:
#   one.sample.to.zero = TRUE   mu = 0; each time bin gets its own Direction.
#   one.sample.to.zero = FALSE  mu = mean z.score within
#                               [min.time.normalization, max.time.normalization];
#                               one Direction for the whole trace. The bin at
#                               norm.time == 0 decides first, then any
#                               significant bin (p < 0.05).
#
# Returns the per-bin mean trace (norm.time, Time.period, z.score) with a
# Direction column. With fewer than three rows in the spread table, Direction
# is "To few trials". On error the error is logged and the mean trace computed
# so far is returned.
one.sample.threshold <- function(Data,
                                 threshold.min.time = 0, threshold.max.time = 1,
                                 min.time.normalization = -2, max.time.normalization = -1,
                                 one.sample.to.zero = FALSE) {
  tryCatch({
    flog.debug("Thresholding by one sample t.test", name = log.name)
    flog.debug("Calculating mean baseline firing rate", name = log.name)

    # mu of the t-tests: 0 for the time-wise comparison, otherwise the mean
    # firing rate within the selected baseline window
    if (one.sample.to.zero) {
      flog.debug("Setting mu to 0", name = log.name)
      mean.baseline <- 0
    } else {
      flog.debug("Calculating mean baseline firing", name = log.name)
      mean.baseline <- Data %>%
        filter(between(norm.time, min.time.normalization, max.time.normalization)) %>%
        summarize(mean.z.score = mean(z.score, na.rm = TRUE)) %>%
        pull(mean.z.score)
    }

    flog.debug("Calculating mean unit firing", name = log.name)
    # time-wise mean firing rate, collapsing all event occurrences
    Data.mean <- Data %>%
      group_by(norm.time, Time.period) %>%
      summarise(z.score = mean(z.score, na.rm = TRUE)) %>%
      ungroup()

    if (length(mean.baseline) != 1) {
      flog.error("Calculating mean baseline firing rate", name = log.name)
      return(mean.baseline)
    }

    flog.debug("Spreading data", name = log.name)
    # one column per time bin of the threshold window
    Data.spread <- Data %>%
      filter(between(norm.time, threshold.min.time, threshold.max.time)) %>%
      spread(key = norm.time, value = z.score) %>%
      select(-ID, -EventID, -Time.period)

    if (!(nrow(Data.spread) > 2 && ncol(Data.spread) >= 1)) {
      flog.debug("Cannot calculate t.test", name = log.name)
      Data.mean <- Data.mean %>% mutate(Direction = "To few trials")  # was Not computed
      return(Data.mean)
    }

    flog.debug("Calculating t.test", name = log.name)
    # a failing t.test (e.g. constant data) yields NA instead of an error
    t.test_safe <- possibly(t.test, NA_real_)
    t.results <- map(Data.spread, t.test_safe, mu = mean.baseline)

    # `$` on an NA result errors, which possibly() also turns into NA
    get_statistic <- possibly(function(x) as.numeric(x$statistic), NA_real_)
    get_p_value <- possibly(function(x) as.numeric(x$p.value), NA_real_)

    stat.data <- data.frame(
      T.val = map_dbl(t.results, get_statistic),
      p.val = map_dbl(t.results, get_p_value),
      norm.time = as.numeric(names(t.results)),
      stringsAsFactors = FALSE)

    flog.debug("Thresholding", name = log.name)

    if (one.sample.to.zero) {
      # comparison to 0: threshold each time bin separately
      Data.mean <- Data.mean %>%
        left_join(stat.data, by = "norm.time") %>%
        mutate(Direction = case_when(is.na(p.val) ~ "NA in P value", # was Not computed
                                     is.na(T.val) ~ "NA in T value", # was Not computed
                                     T.val > 0 & p.val < 0.05 ~ "Up",
                                     T.val < 0 & p.val < 0.05 ~ "Down",
                                     TRUE ~ "Nonresponder")) %>%
        select(-T.val, -p.val)
      return(Data.mean)
    }

    # comparison to the baseline mean: one threshold for all post time bins.
    # Add a non-significant row when time point 0 is missing (row-wise check
    # on the norm.time column).
    if (!any(stat.data$norm.time == 0, na.rm = TRUE)) {
      flog.debug("Missing 0 time point, adding p.value = 1", name = log.name)
      stat.data <- bind_rows(stat.data, data.frame(norm.time = 0,
                                                   T.val = 0,
                                                   p.val = 1,
                                                   stringsAsFactors = FALSE))
    }

    # time bins where z.score takes fewer than 2 distinct values are treated
    # as not significant
    constant.times <- Data %>%
      filter(between(norm.time, threshold.min.time, threshold.max.time)) %>%
      group_by(norm.time) %>%
      summarise(n = n_distinct(z.score)) %>%
      ungroup() %>%
      filter(n < 2) %>%
      pull(norm.time)

    if (length(constant.times) > 0) {
      stat.data <- stat.data %>%
        mutate(T.val = if_else(norm.time %in% constant.times, 0, T.val),
               p.val = if_else(norm.time %in% constant.times, 1, p.val))
    }

    direction <- stat.data %>%
      summarise(Direction = case_when(all(is.na(p.val)) ~ "All NA in P value", # was Not computed
                                      all(is.na(T.val)) ~ "All NA in T value", # was Not computed
                                      T.val[norm.time == 0] > 0 & p.val[norm.time == 0] < 0.05 ~ "Up",
                                      T.val[norm.time == 0] < 0 & p.val[norm.time == 0] < 0.05 ~ "Down",
                                      any(T.val > 0 & p.val < 0.05) ~ "Up",
                                      any(T.val < 0 & p.val < 0.05) ~ "Down",
                                      all(p.val == 1) ~ "Cannot compute t.test",
                                      TRUE ~ "Nonresponder"))

    Data.mean <- Data.mean %>% mutate(Direction = direction$Direction)
    return(Data.mean)
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data.mean)
  }) # end
}




# Auxiliary modules -------------------------------------------------------------
# get_perievents() chains spread_data(), get_idx() and select_events() to build
# the peri-event data frame and the index data frame.
#
# Arguments:
#   Data              long-format input data.
#   group.columns     columns identifying one group; they are pasted into ID.
#   protocol.columns  further columns to keep. When permute = TRUE these
#                     presumably must include TrialCnt, which spread_data()
#                     groups on — confirm.
#   perievent.columns event columns whose values > 0 mark event onsets.
#   before, after     window limits around each onset.
#   bin.size          duration of one row.
#   permute           passed to spread_data() to shuffle the event columns.
#   unit.column       unit column(s) to keep. When NULL, a variable
#                     `unit.column` visible from where this function is
#                     defined is used (the previous, implicit behaviour).
#                     NOTE(review): spread_data() needs `Unit` and `z.score`;
#                     presumably these come in via unit.column — confirm.
#
# Returns list(Idx = <index table>, Data.event = <long peri-event table with a
# Time.period column>). On error the error is logged and the value of
# flog.error() is returned.
get_perievents <- function(Data, group.columns, protocol.columns, perievent.columns,
                           before = -5, after = 5,
                           bin.size = 0.25, permute = FALSE, unit.column = NULL){
  # environment where the free variable `unit.column` used to be resolved
  lookup.env <- parent.env(environment())

  tryCatch({
    flog.info("Getting idx table and selecting events...", name = log.name)

    if (is.null(unit.column)) {
      unit.column <- get("unit.column", envir = lookup.env)
    }

    flog.debug("Filtering table and processing", name = log.name)

    Data.event <- Data %>%
      select_at(c(group.columns, unit.column, protocol.columns, perievent.columns)) %>%
      group_by_at(group.columns) %>%
      nest() %>%
      ungroup() %>%
      unite(ID, group.columns, remove = FALSE) %>%
      mutate(spread.data = future_map2(data, ID, spread_data,
                                       permute = permute,
                                       perievent.columns = perievent.columns)) %>%
      select(-data) %>%
      mutate(idx = future_map2(spread.data, ID, get_idx,
                               before = before, after = after, bin.size = bin.size,
                               perievent.columns = perievent.columns)) %>%
      mutate(events = future_map2(spread.data, idx, select_events,
                                  before = before, after = after, bin.size = bin.size))

    flog.debug("Making Idx and event tables...", name = log.name)
    Idx <- Data.event %>%
      select(-events, -spread.data) %>%
      unnest(idx)

    flog.debug("Adding Time period", name = log.name)
    Data.event <- Data.event %>%
      select(-idx, -spread.data) %>%
      unnest(events) %>%
      mutate(Time.period = if_else(norm.time < 0, "Pre_Event_onset", "Post_Event_onset"))

    flog.info("Idx done...", name = log.name)
    return(list(Idx = Idx, Data.event = Data.event))
  }, error = function(e) {
    flog.error("ERROR!", name = log.name)
    flog.error(e, name = log.name)
  })
}

# Binds together the above functions to threshold Unit firing.
# Three modes, checked in this order:
#   normalize.spikes.per.event = TRUE  normalize_perievents() per group, then
#                                      assign_threshold() (z-threshold).
#   one.sample.t.threshold = TRUE      one.sample.threshold() per
#                                      group/Event/Unit.
#   otherwise                          mean z.score per group/Event/Unit/time
#                                      bin, then assign_threshold().
# Returns the thresholded data frame. On error the error is logged and the
# value of flog.error() is returned.
threshold_units <- function(Data, group.columns,
                            normalize.spikes.per.event, min.time.normalization, max.time.normalization,
                            z.threshold, threshold.min.time, threshold.max.time,
                            one.sample.t.threshold, one.sample.to.zero){
  tryCatch({
    flog.info("Initializing thresholding", name = log.name)

    if (normalize.spikes.per.event) {
      # normalize the perievent firing to the selected baseline
      flog.info("Normalizing the perievent firing to baseline", name = log.name)
      Data.normalized <- Data %>%
        group_by_at(c(group.columns)) %>% # all events of a given session/treatment
        nest() %>%
        mutate(normalize = future_map(data,
                                      normalize_perievents,
                                      min.time.normalization, max.time.normalization)) %>%
        select(-data) %>%
        unnest() %>%
        ungroup()

      flog.debug("Assigning z-thresholding to perievent normalized data", name = log.name)
      # NOTE(review): this function's `group.columns` is not passed on;
      # assign_threshold() resolves its grouping columns itself — confirm they match.
      Data.normalized <- assign_threshold(Data.normalized,
                                          z.threshold,
                                          threshold.min.time, threshold.max.time)

    } else if (one.sample.t.threshold) {
      flog.info("Calculating one sample t.test ", name = log.name)
      Data.normalized <- Data %>%
        group_by_at(c(group.columns, "Event", "Unit")) %>% # one nest per single unit/event
        nest() %>%
        mutate(Direction = future_map(data,
                                      one.sample.threshold,
                                      threshold.min.time = threshold.min.time,
                                      threshold.max.time = threshold.max.time,
                                      min.time.normalization = min.time.normalization,
                                      max.time.normalization = max.time.normalization,
                                      one.sample.to.zero = one.sample.to.zero)) %>%
        select(-data) %>%
        unnest() %>%
        ungroup()

    } else {
      # not normalizing and not using t.test thresholding
      flog.info("Calculating mean", name = log.name)
      Data.normalized <- Data %>%
        group_by_at(c(group.columns, "Event", "Unit", "norm.time", "Time.period")) %>%
        summarise(z.score = mean(z.score, na.rm = TRUE)) %>%
        ungroup()

      flog.info("Assigning z thresholding to non normalized data", name = log.name)
      Data.normalized <- assign_threshold(Data.normalized,
                                          z.threshold,
                                          threshold.min.time, threshold.max.time)
    }

    # reached by every mode (previously unreachable after the early returns)
    flog.info("Finished thresholding step", name = log.name)
    return(Data.normalized)
  }, error = function(e) {
    flog.error("ERROR!", name = log.name)
    flog.error(e, name = log.name)
  })
}
